/*
 * LingPipe v. 4.1.0
 * Copyright (C) 2003-2011 Alias-i
 *
 * This program is licensed under the Alias-i Royalty Free License
 * Version 1 WITHOUT ANY WARRANTY, without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the Alias-i
 * Royalty Free License Version 1 for more details.
 *
 * You should have received a copy of the Alias-i Royalty Free License
 * Version 1 along with this program; if not, visit
 * http://alias-i.com/lingpipe/licenses/lingpipe-license-1.txt or contact
 * Alias-i, Inc. at 181 North 11th Street, Suite 401, Brooklyn, NY 11211,
 * +1 (718) 290-9170.
 */

package com.aliasi.cluster;

import com.aliasi.classify.PrecisionRecallEvaluation;

import com.aliasi.util.Distance;
import com.aliasi.util.Tuple;

import java.util.Collections;
import java.util.HashSet;
import java.util.Set;


/**
 * A <code>ClusterScore</code> provides a range of cluster scoring
 * metrics for reference partitions versus response partitions.
 *
 * <P>Traditional evaluation measures for pairs of partitions involve
 * comparing the equivalence relations generated by the partitions
 * pointwise.  A partition defines an equivalence relation in the
 * usual way: a pair <code>(A,B)</code> is in the equivalence if there
 * is a cluster that contains both <code>A</code> and <code>B</code>.
 * Each element is assumed to be equal to itself.  A pair
 * <code>(A,B)</code> is a true positive if it is an equivalence in
 * the reference and response clusters, a false positive if it is in
 * the response but not the reference, and so on.  The resulting
 * precision-recall statistics over the relations are reported through
 * {@link #equivalenceEvaluation()}.
 *
 * <P>The scoring used for the Message Understanding Conferences is:
 *
 * <blockquote><code>
 * mucRecall(referencePartition,responsePartition)
 * <br> = <big><big>&Sigma;</big></big><sub><sub>c in referencePartition</sub></sub>
 *       (size(c) - overlap(c,responsePartition))
 * <br> / <big><big>&Sigma;</big></big><sub><sub>c in referencePartition</sub></sub>
 *       ( size(c) - 1 )
 * </code></blockquote>
 *
 * where <code>size(c)</code> is the number of elements in the
 * cluster <code>c</code>, and <code>overlap(c,responsePartition)</code>
 * is the number of clusters in the response partition that intersect
 * the cluster <code>c</code>.  Precision is defined dually by
 * reversing the roles of reference and response, and f-measure is defined
 * as usual.   Further details and examples can be found in:
 *
 * <blockquote>
 *   Marc Vilain, John Burger, John Aberdeen, Dennis Connolly, and
 *   Lynette Hirschman.
 *   1995.
 *   <a href="http://acl.ldc.upenn.edu/M/M95/M95-1005.pdf">A model-theoretic coreference scoring scheme.</a>
 *   In <i>Proceedings of the 6th Message Understanding Conference (MUC6)</i>.
 *   45--52.  Morgan Kaufmann.
 * </blockquote>
 *
 * <P>B-cubed cluster scoring was defined as an alternative to the MUC
 * scoring metric.  There are two variants of B-cubed cluster precision, both
 * of which are weighted averages of a per-element precision score:
 *
 * <blockquote><code>
 * b3Precision(A,referencePartition,responsePartition)
 * <br> = |cluster(responsePartition,A) INTERSECT cluster(referencePartition,A)|
 * <br> / |cluster(responsePartition,A)|
 * </code></blockquote>
 *
 * where <code>cluster(partition,a)</code> is the cluster in the
 * partition <code>partition</code> containing the element <code>a</code>;
 * in other words, this is <code>A</code>'s equivalence class and contains
 * the set of all elements equivalent to <code>A</code> in the partition.
 *
 * <P>For the uniform cluster method, each cluster in the response partition is
 * weighted equally, and each element is weighted equally within a cluster:
 *
 * <blockquote><code>
 * b3ClusterPrecision(referencePartition,responsePartition)
 * <br> = <big><big>&Sigma;</big></big><sub><sub>a</sub></sub>
 *        b3Precision(a,referencePartition,responsePartition)
 * <br> / (|responsePartition| * |cluster(responsePartition,a)|)
 * </code></blockquote>
 *
 * <P>For the uniform element method, each element <code>a</code> is weighted uniformly:
 *
 * <blockquote><code>
 * b3ElementPrecision(ReferencePartition,ResponsePartition)
 * <br> = <big><big>&Sigma;</big></big><sub><sub>a</sub></sub>
 *        b3Precision(a,referencePartition,responsePartition) / numElements
 * </code></blockquote>
 *
 * where <code>numElements</code> is the total number of elements in
 * the partitions.  For both B-cubed approaches, recall is defined
 * dually by switching the roles of reference and response, and the
 * F<sub><sub>1</sub></sub>-measure is defined in the usual way.
 *
 * <P>The B-cubed method is described in detail in:
 *
 * <blockquote>
 *   Bagga, Amit and Breck Baldwin.
 *   1998.
 *   <a href="ftp://ftp.cis.upenn.edu/pub/breck/scoring-paper.ps.gz">Algorithms
 *   for scoring coreference chains</a>.
 *   In <i>Proceedings of the First International Conference
 *   on Language Resources and Evaluation Workshop on Linguistic
 *   Coreference.</i>
 * </blockquote>
 *
 * <P>As an example, consider the following two partitions:
 *
 * <blockquote><code>
 *   reference = { {1, 2, 3, 4, 5}, {6, 7}, {8, 9, A, B, C} }
 *   <br>
 *   response = { { 1, 2, 3, 4, 5, 8, 9, A, B, C }, { 6, 7} }
 * </code></blockquote>
 *
 * which produce the following values for method calls:
 *
 * <blockquote>
 * <table border='1' cellpadding='5'>
 * <tr><td><i>Method</i></td><td><i>Result</i></td></tr>
 * <tr><td>{@link #equivalenceEvaluation()}</td>
 *     <td>PrecisionRecallEvaluation(54,0,50,40)
 *         <br>TP=54; FN=0; FP=50; TN=40</td></tr>
 * <tr><td>{@link #mucPrecision()}</td>
 *     <td>0.9</td></tr>
 * <tr><td>{@link #mucRecall()}</td>
 *     <td>1.0</td></tr>
 * <tr><td>{@link #mucF()}</td>
 *     <td>0.947</td></tr>
 * <tr><td>{@link #b3ElementPrecision()}</td>
 *     <td>0.583</td></tr>
 * <tr><td>{@link #b3ElementRecall()}</td>
 *     <td>1.0</td></tr>
 * <tr><td>{@link #b3ElementF()}</td>
 *     <td>0.737</td></tr>
 * <tr><td>{@link #b3ClusterPrecision()}</td>
 *     <td>0.75</td></tr>
 * <tr><td>{@link #b3ClusterRecall()}</td>
 *     <td>1.0</td></tr>
 * <tr><td>{@link #b3ClusterF()}</td>
 *     <td>0.857</td></tr>
 * </table>
 * </blockquote>
 *
 * <p>Note that there are additional scoring metrics within the {@link
 * Dendrogram} class for cophenetic correlation and dendrogram-specific
 * within-cluster scatter.
 *
 * @author  Bob Carpenter
 * @version 3.8
 * @since   LingPipe2.0
 * @param <E> the type of objects being clustered
 */
public class ClusterScore<E> {

    /** Pairwise equivalence confusion counts, computed once at construction. */
    private final PrecisionRecallEvaluation mPrEval;

    /** The gold-standard partition; stored by reference, not copied. */
    private final Set<? extends Set<? extends E>> mReferencePartition;

    /** The partition under evaluation; stored by reference, not copied. */
    private final Set<? extends Set<? extends E>> mResponsePartition;

    /**
     * Construct a cluster score object from the specified reference and
     * response partitions.  A partition is a set of disjoint sets of
     * elements.  A partition defines an equivalence relation where the
     * disjoint sets represent the equivalence classes.
     *
     * <P>The reference partition is taken to represent the "truth"
     * or the "correct" answer, also known as the "gold standard".
     * The response partition is the partition to evaluate against the
     * reference partition.
     *
     * <P>If the specified partitions are not over the same sets
     * or if the equivalence classes are not disjoint, an illegal
     * argument exception is raised.
     *
     * <P>The partitions are retained by reference rather than copied,
     * and the equivalence evaluation is computed eagerly here, so the
     * partitions should not be modified after construction.
     *
     * @param referencePartition Partition of reference elements.
     * @param responsePartition Partition of response elements.
     * @throws IllegalArgumentException If the partitions are not
     * valid partitions over the same set of elements.
     */
    public ClusterScore(Set<? extends Set<? extends E>> referencePartition,
                        Set<? extends Set<? extends E>> responsePartition) {
        assertPartitionSameSets(referencePartition,responsePartition);
        mReferencePartition = referencePartition;
        mResponsePartition = responsePartition;
        mPrEval = calculateConfusionMatrix();
    }

    /**
     * Returns the precision-recall evaluation corresponding to
     * equalities in the reference and response clusterings.
     *
     * @return The precision-recall evaluation.
     */
    public PrecisionRecallEvaluation equivalenceEvaluation() {
        return mPrEval;
    }

    /**
     * Returns the precision as defined by MUC.  Precision is
     * computed as MUC recall with the roles of the reference and
     * response partitions swapped.
     *
     * @return The precision as defined by MUC.
     */
    public double mucPrecision() {
        return mucRecall(mResponsePartition,mReferencePartition);
    }

    /**
     * Returns the recall as defined by MUC.  If the reference
     * partition contains no cluster with more than one element, the
     * result is <code>1.0</code>.
     *
     * @return The recall as defined by MUC.
     */
    public double mucRecall() {
        return mucRecall(mReferencePartition,mResponsePartition);
    }

    /**
     * Returns the F<sub><sub>1</sub></sub> measure of the MUC recall
     * and precision.  If both precision and recall are zero, the
     * result is <code>0.0</code>.
     *
     * @return The F measure as defined by MUC.
     */
    public double mucF() {
        return f(mucPrecision(),mucRecall());
    }

    /**
     * Returns the precision as defined by B<sup>3</sup> metric with
     * equal cluster weighting.  Precision is computed as the
     * cluster-weighted recall with the roles of the reference and
     * response partitions swapped.
     *
     * @return The B-cubed equal cluster precision.
     */
    public double b3ClusterPrecision() {
        return b3ClusterRecall(mResponsePartition,mReferencePartition);
    }

    /**
     * Returns the recall as defined by B<sup>3</sup> metric with
     * equal cluster weighting.
     *
     * @return The B-cubed equal cluster recall.
     */
    public double b3ClusterRecall() {
        return b3ClusterRecall(mReferencePartition,mResponsePartition);
    }

    /**
     * Returns the F<sub><sub>1</sub></sub> measure of the
     * B<sup>3</sup> precision and recall metrics with equal cluster
     * weighting.  If both precision and recall are zero, the result
     * is <code>0.0</code>.
     *
     * @return The B-cubed equal cluster F measure.
     */
    public double b3ClusterF() {
        return f(b3ClusterPrecision(),b3ClusterRecall());
    }

    /**
     * Returns the precision as defined by B<sup>3</sup> metric with
     * equal element weighting.  Precision is computed as the
     * element-weighted recall with the roles of the reference and
     * response partitions swapped.
     *
     * @return The B-cubed equal element precision.
     */
    public double b3ElementPrecision() {
        return b3ElementRecall(mResponsePartition,mReferencePartition);
    }

    /**
     * Returns the recall as defined by B<sup>3</sup> metric with
     * equal element weighting.
     *
     * @return The B-cubed equal element recall.
     */
    public double b3ElementRecall() {
        return b3ElementRecall(mReferencePartition,mResponsePartition);
    }

    /**
     * Returns the F<sub><sub>1</sub></sub> measure of the
     * B<sup>3</sup> precision and recall metrics with equal element
     * weighting.  If both precision and recall are zero, the result
     * is <code>0.0</code>.
     *
     * @return The B-cubed equal element F measure.
     */
    public double b3ElementF() {
        return f(b3ElementPrecision(),b3ElementRecall());
    }

    /**
     * Returns the set of true positive relations for this scoring.
     * Each relation is an instance of {@link Tuple}.  These true
     * positives will include both <code>(x,y)</code> and
     * <code>(y,x)</code> for a true positive relation between
     * <code>x</code> and <code>y</code>.  A fresh set is built on
     * every call.
     *
     * @return The set of true positives.
     */
    public Set<Tuple<E>> truePositives() {
        Set<Tuple<E>> inReference = toEquivalences(mReferencePartition);
        Set<Tuple<E>> inResponse = toEquivalences(mResponsePartition);
        // reference AND response
        inReference.retainAll(inResponse);
        return inReference;
    }

    /**
     * Returns the set of false positive relations for this scoring.
     * Each relation is an instance of {@link Tuple}.  The false
     * positives will include both <code>(x,y)</code> and
     * <code>(y,x)</code> for a false positive relation between
     * <code>x</code> and <code>y</code>.  A fresh set is built on
     * every call.
     *
     * @return The set of false positives.
     */
    public Set<Tuple<E>> falsePositives() {
        Set<Tuple<E>> inReference = toEquivalences(mReferencePartition);
        Set<Tuple<E>> inResponse = toEquivalences(mResponsePartition);
        // response MINUS reference
        inResponse.removeAll(inReference);
        return inResponse;
    }

    /**
     * Returns the set of false negative relations for this scoring.
     * Each relation is an instance of {@link Tuple}.  The false
     * negative set will include both <code>(x,y)</code> and
     * <code>(y,x)</code> for a false negative relation between
     * <code>x</code> and <code>y</code>.  A fresh set is built on
     * every call.
     *
     * @return The set of false negatives.
     */
    public Set<Tuple<E>> falseNegatives() {
        Set<Tuple<E>> inReference = toEquivalences(mReferencePartition);
        Set<Tuple<E>> inResponse = toEquivalences(mResponsePartition);
        // reference MINUS response
        inReference.removeAll(inResponse);
        return inReference;
    }

    /**
     * Computes the pairwise confusion counts over ordered pairs of
     * elements, self-pairs included.  The universe of pairs has size
     * <code>numElements * numElements</code>; true negatives are
     * whatever remains after the other three counts are subtracted.
     *
     * @return The evaluation constructed from the counts in the
     * order (tp, fn, fp, tn).
     */
    private PrecisionRecallEvaluation calculateConfusionMatrix() {
        Set<Tuple<E>> referenceEquivalences = toEquivalences(mReferencePartition);
        Set<Tuple<E>> responseEquivalences = toEquivalences(mResponsePartition);
        long tp = 0;
        long fn = 0;
        for (Tuple<E> tuple : referenceEquivalences) {
            // removing matched pairs leaves exactly the false positives behind
            if (responseEquivalences.remove(tuple)) {
                ++tp;
            } else {
                ++fn;
            }
        }
        long fp = responseEquivalences.size();
        long numElements = ClusterScore.<E>elementsOf(mReferencePartition).size();
        long totalCount = numElements * numElements;
        long tn = totalCount - tp - fn - fp;
        return new PrecisionRecallEvaluation(tp,fn,fp,tn);
    }

    /**
     * Returns a string representation of the statistics for this
     * score.  The string includes the information in all of the
     * methods of this class: b3 scores by cluster and by element,
     * muc scores, and the precision-recall evaluation based on
     * equivalence.
     *
     * @return String-based representation of this score.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();

        sb.append("CLUSTER SCORE");
        sb.append("\nEquivalence Evaluation\n");
        sb.append(mPrEval.toString());

        sb.append("\nMUC Evaluation");
        sb.append("\n  MUC Precision = ").append(mucPrecision());
        sb.append("\n  MUC Recall = ").append(mucRecall());
        sb.append("\n  MUC F(1) = ").append(mucF());

        sb.append("\nB-Cubed Evaluation");
        sb.append("\n  B3 Cluster Averaged Precision = ").append(b3ClusterPrecision());
        sb.append("\n  B3 Cluster Averaged Recall = ").append(b3ClusterRecall());
        sb.append("\n  B3 Cluster Averaged F(1) = ").append(b3ClusterF());
        sb.append("\n  B3 Element Averaged Precision = ").append(b3ElementPrecision());
        sb.append("\n  B3 Element Averaged Recall = ").append(b3ElementRecall());
        sb.append("\n  B3 Element Averaged F(1) = ").append(b3ElementF());

        return sb.toString();
    }

    /**
     * Returns the within-cluster scatter measure for the specified
     * clustering with respect to the specified distance.  The
     * within-cluster scatter is simply the sum of the scatters for
     * each set in the clustering; see {@link #scatter(Set,Distance)}
     * for a definition of scatter.
     *
     * <blockquote><pre>
     * withinClusterScatter(clusters,distance)
     *   = <big>&Sigma;</big><sub><sub>cluster in clusters</sub></sub> scatter(cluster,distance)</pre></blockquote>
     *
     * <p>As the number of clusters increases, the within-cluster
     * scatter decreases monotonically.  Typically, this is used
     * to determine how many clusters to return, by inspecting
     * a plot of within-cluster scatter against number of clusters
     * and looking for a &quot;knee&quot; in the graph.
     *
     * @param clustering Clustering to evaluate.
     * @param distance Distance against which to evaluate.
     * @return The within-cluster scatter score.
     * @param <E> the type of objects being clustered
     */
    static public <E> double
        withinClusterScatter(Set<? extends Set<? extends E>> clustering,
                             Distance<? super E> distance) {

        double total = 0.0;
        for (Set<? extends E> cluster : clustering) {
            total += scatter(cluster,distance);
        }
        return total;
    }

    /**
     * Returns the scatter for the specified cluster based on the
     * specified distance.  The scatter is the sum of all of the
     * pairwise distances between elements, with each pair of elements
     * counted once.  Abusing notation to use <code>xs[i]</code> for
     * the <code>i</code>th element returned by the set's iterator,
     * scatter is defined by:
     *
     * <blockquote><pre>
     * scatter(xs,distance)
     *   = <big>&Sigma;</big><sub><sub>i</sub></sub> <big>&Sigma;</big><sub><sub>j &lt; i</sub></sub> distance(xs[i],xs[j])</pre></blockquote>
     *
     * Note that elements are not compared to themselves.  This
     * presupposes a distance for which the distance of an element to
     * itself is zero and which is symmetric.
     *
     * @param cluster Cluster to evaluate.
     * @param distance Distance against which to evaluate.
     * @return The total scatter for the specified set.
     * @param <E> the type of objects being clustered
     */
    static public <E> double scatter(Set<? extends E> cluster,
                                     Distance<? super E> distance) {

        // an array gives positional access so each unordered pair is visited once;
        // the cast is safe because every member of the cluster is an E
        @SuppressWarnings("unchecked")
        E[] elements = (E[]) cluster.toArray();
        double scatter = 0.0;
        for (int i = 0; i < elements.length; ++i) {
            for (int j = i + 1; j < elements.length; ++j) {
                scatter += distance.distance(elements[i],elements[j]);
            }
        }
        return scatter;
    }

    /**
     * Expands a partition into the set of ordered pairs it relates.
     * For every class, all pairs <code>(x,y)</code> of its members
     * are produced, including the self-pairs <code>(x,x)</code> and
     * both orderings of distinct members, so that the counts in the
     * equivalence evaluation are complete.
     *
     * @param partition Partition to expand.
     * @return A new, modifiable set of the related pairs.
     */
    Set<Tuple<E>> toEquivalences(Set<? extends Set<? extends E>> partition) {
        Set<Tuple<E>> equivalences = new HashSet<Tuple<E>>();
        for (Set<? extends E> equivalenceClass : partition) {
            for (E left : equivalenceClass) {
                for (E right : equivalenceClass) {
                    equivalences.add(Tuple.<E>create(left,right));
                }
            }
        }
        return equivalences;
    }

    /**
     * Computes B-cubed recall with every element weighted equally,
     * namely by one over the number of elements in the first
     * partition.
     *
     * @param referencePartition Partition whose classes are to be recovered.
     * @param responsePartition Partition being scored.
     * @return The element-weighted sum of per-element recalls.
     */
    private static <F> double b3ElementRecall(Set<? extends Set<? extends F>> referencePartition,
                                              Set<? extends Set<? extends F>> responsePartition) {
        Set<F> elementsOfReference = ClusterScore.<F>elementsOf(referencePartition);
        // the weight is the same for every element, so compute it once;
        // for an empty partition it is never used because the loops do not run
        double weight = uniformElementWeight(elementsOfReference.size());
        double score = 0.0;
        for (Set<? extends F> referenceEqClass : referencePartition) {
            for (F referenceEqClassElt : referenceEqClass) {
                score += weight
                    * b3Recall(referenceEqClassElt,
                               referenceEqClass,responsePartition);
            }
        }
        return score;
    }

    /**
     * Returns the per-element weight under uniform element weighting.
     *
     * @param numElements Total number of elements.
     * @return One over the number of elements.
     */
    private static double uniformElementWeight(int numElements) {
        return 1.0 / (double) numElements;
    }

    /**
     * Returns the per-element weight under uniform cluster weighting:
     * each cluster gets an equal share, split equally among its
     * members.  The product is formed in double arithmetic so that
     * large sizes cannot overflow <code>int</code>.
     *
     * @param classSize Number of elements in the element's cluster.
     * @param numClusters Number of clusters sharing the total weight.
     * @return One over the product of the two sizes.
     */
    private static double uniformClusterWeight(int classSize, int numClusters) {
        return 1.0 / ((double) classSize * (double) numClusters);
    }

    /**
     * Counts the classes of a partition that have at least one member.
     *
     * @param partition Partition to inspect.
     * @return Number of non-empty classes.
     */
    private static <F> int countNonEmpty(Set<? extends Set<? extends F>> partition) {
        int count = 0;
        for (Set<? extends F> eqClass : partition) {
            if (!eqClass.isEmpty()) {
                ++count;
            }
        }
        return count;
    }

    /**
     * Computes B-cubed recall with every non-empty cluster of the
     * first partition weighted equally and elements weighted equally
     * within their cluster.  Empty classes contribute no elements and
     * are left out of the cluster count so the weights still sum to one.
     *
     * @param referencePartition Partition whose classes are to be recovered.
     * @param responsePartition Partition being scored.
     * @return The cluster-weighted sum of per-element recalls.
     */
    private static <F> double b3ClusterRecall(Set<? extends Set<? extends F>> referencePartition,
                                              Set<? extends Set<? extends F>> responsePartition) {
        int numClusters = ClusterScore.<F>countNonEmpty(referencePartition);
        double score = 0.0;
        for (Set<? extends F> referenceEqClass : referencePartition) {
            if (referenceEqClass.isEmpty()) {
                continue;
            }
            // constant across the members of one cluster
            double weight = uniformClusterWeight(referenceEqClass.size(),numClusters);
            for (F referenceEqClassElt : referenceEqClass) {
                score += weight
                    * b3Recall(referenceEqClassElt,
                               referenceEqClass,responsePartition);
            }
        }
        return score;
    }

    /**
     * Computes the B-cubed recall for a single element: the fraction
     * of the element's reference class that also appears in the
     * response class containing the element.
     *
     * @param element Element being scored.
     * @param referenceEqClass Class of the element in the reference.
     * @param responsePartition Partition in which to look up the element.
     * @return The per-element recall.
     * @throws IllegalArgumentException If no class of the response
     * partition contains the element.
     */
    private static <F> double b3Recall(F element,
                                       Set<? extends F> referenceEqClass,
                                       Set<? extends Set<? extends F>> responsePartition) {
        Set<? extends F> responseClass = getEquivalenceClass(element,responsePartition);
        return ClusterScore.<F>recallSets(referenceEqClass,responseClass);
    }

    /**
     * Returns the fraction of the reference set that is found in the
     * response set, or <code>1.0</code> if the reference set is empty.
     *
     * @param referenceSet Set to be recovered.
     * @param responseSet Set proposed.
     * @return Size of the intersection over size of the reference set.
     */
    private static <F> double recallSets(Set<? extends F> referenceSet, Set<? extends F> responseSet) {
        if (referenceSet.isEmpty()) {
            return 1.0;
        }
        return ((double) intersectionSize(referenceSet,responseSet))
            / (double) referenceSet.size();
    }

    /**
     * Counts the members of the first set that are contained in the second.
     *
     * @param set1 Set that is iterated.
     * @param set2 Set that is probed for membership.
     * @return Number of shared members.
     */
    private static <F> long intersectionSize(Set<? extends F> set1, Set<? extends F> set2) {
        long count = 0;
        for (F member : set1) {
            if (set2.contains(member)) {
                ++count;
            }
        }
        return count;
    }

    /**
     * Verifies that both arguments are valid partitions and that they
     * cover the same set of elements.
     *
     * @param set1 First partition.
     * @param set2 Second partition.
     * @throws IllegalArgumentException If either partition has
     * overlapping classes or the element sets differ.
     */
    private static <F> void assertPartitionSameSets(Set<? extends Set<? extends F>> set1,
                                                    Set<? extends Set<? extends F>> set2) {
        ClusterScore.<F>assertValidPartition(set1);
        ClusterScore.<F>assertValidPartition(set2);
        if (!elementsOf(set1).equals(elementsOf(set2))) {
            String msg = "Partitions must be of same sets.";
            throw new IllegalArgumentException(msg);
        }
    }

    /**
     * Verifies that no element occurs in more than one class of the
     * specified partition.
     *
     * @param partition Partition to check.
     * @throws IllegalArgumentException If an element is found in two classes.
     */
    private static <F> void assertValidPartition(Set<? extends Set<? extends F>> partition) {
        Set<F> eltsSoFar = new HashSet<F>();
        for (Set<? extends F> eqClass : partition) {
            for (F member : eqClass) {
                // add() returns false when the member was already seen in an earlier class
                if (!eltsSoFar.add(member)) {
                    String msg = "Partitions must not contain overlapping members."
                        + " Found overlapping element=" + member;
                    throw new IllegalArgumentException(msg);
                }
            }
        }
    }

    /**
     * Returns the first class encountered in the partition that
     * contains the specified element.
     *
     * @param element Element to look up.
     * @param partition Partition to search.
     * @return The class containing the element.
     * @throws IllegalArgumentException If no class contains the element.
     */
    private static <F> Set<? extends F> getEquivalenceClass(F element,
                                                            Set<? extends Set<? extends F>> partition) {
        for (Set<? extends F> equivalenceClass : partition) {
            if (equivalenceClass.contains(element)) {
                return equivalenceClass;
            }
        }
        throw new IllegalArgumentException("Element must be in an equivalence class in partition.");
    }

    /**
     * Returns the union of all classes of the specified partition.
     *
     * @param partition Partition to flatten.
     * @return A new set holding every element of every class.
     */
    private static <F> Set<F> elementsOf(Set<? extends Set<? extends F>> partition) {
        Set<F> elementSet = new HashSet<F>();
        for (Set<? extends F> eqClass : partition) {
            elementSet.addAll(eqClass);
        }
        return elementSet;
    }

    /**
     * Returns the balanced F-measure (harmonic mean) of the specified
     * precision and recall.  When the two sum to zero the harmonic
     * mean is undefined (0/0); <code>0.0</code> is returned in that
     * case rather than <code>NaN</code>.
     *
     * @param precision Precision value.
     * @param recall Recall value.
     * @return The F<sub><sub>1</sub></sub> measure.
     */
    private static double f(double precision,
                            double recall) {
        double sum = precision + recall;
        if (sum == 0.0) {
            return 0.0;
        }
        return 2.0 * precision * recall / sum;
    }

    /**
     * Computes MUC recall of the second partition against the first.
     * For each non-empty class of the first partition, the numerator
     * gains the class size minus the number of classes of the second
     * partition that intersect it, and the denominator gains the
     * class size minus one.  Empty classes are skipped; they hold no
     * links and would otherwise subtract one from the denominator.
     *
     * @param referencePartition Partition whose links are to be recovered.
     * @param responsePartition Partition being scored.
     * @return Ratio of the sums, or <code>1.0</code> if the
     * denominator is zero.
     */
    private static <F> double mucRecall(Set<? extends Set<? extends F>> referencePartition,
                                        Set<? extends Set<? extends F>> responsePartition) {
        long numerator = 0;
        long denominator = 0;
        for (Set<? extends F> referenceEqClass : referencePartition) {
            if (referenceEqClass.isEmpty()) {
                continue;
            }
            long overlap = 0;
            for (Set<? extends F> responseEqClass : responsePartition) {
                if (!Collections.disjoint(referenceEqClass,responseEqClass)) {
                    ++overlap;
                }
            }
            numerator += referenceEqClass.size() - overlap;
            denominator += referenceEqClass.size() - 1;
        }
        if (denominator == 0) {
            // only singleton classes: there are no links to miss
            return 1.0;
        }
        return ((double) numerator) / (double) denominator;
    }

}
